package org.zjvis.datascience.spark.algorithm;

import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.iforest.IForest;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.types.DataTypes;
import org.zjvis.datascience.spark.util.OptionHelper;
import org.zjvis.datascience.spark.util.OutputResult;
import org.zjvis.datascience.spark.util.UtilTool;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Anomaly-detection operator built on the spark-iforest Isolation Forest estimator.
 *
 * <p>The operator reads a source table (through Spark SQL when {@code isHive} is set,
 * otherwise through {@link UtilTool#readFromGreenPlum}), assembles the configured feature
 * columns into a vector, fits an {@link IForest} model, and writes the input rows plus a
 * string-typed {@code label} column to the target table.
 *
 * @date 2021-12-23
 */
public class IsolationForestAlgorithm extends BaseAlgorithm {

    /** Name of the temporary assembled feature-vector column; dropped before saving. */
    private static final String FEATURES_COL = "features";

    /** Name of the prediction column written to the target table. */
    private static final String LABEL_COL = "label";

    /** Scratch column name used while converting the label to a string. */
    private static final String LABEL_STR_COL = "label_str";

    /** Fixed seed so that repeated runs over the same data build the same forest. */
    private static final long SEED = 123456L;

    // Fallbacks used when the corresponding command-line option is omitted. Without them
    // the primitive fields would stay at 0 and be handed to the IForest setters as-is.
    // NOTE(review): values are meant to mirror the spark-iforest library defaults —
    // confirm against the library version on the classpath.
    private static final int DEFAULT_TREE_NUM = 100;
    private static final double DEFAULT_MAX_SAMPLES = 1.0;
    private static final double DEFAULT_CONTAMINATION = 0.1;
    private static final int DEFAULT_MAX_DEPTH = 10;

    /** Columns assembled into the feature vector (comma-separated on the command line). */
    private String[] featureCols;

    /** Table that receives the labelled rows. */
    private String targetTable;

    /** Table the input rows are read from. */
    private String sourceTable;

    /** Value passed to {@code IForest.setNumTrees}. */
    private int treeNum = DEFAULT_TREE_NUM;

    /** Value passed to {@code IForest.setMaxSamples}. */
    private double maxSamples = DEFAULT_MAX_SAMPLES;

    /** Value passed to {@code IForest.setContamination}. */
    private double contamination = DEFAULT_CONTAMINATION;

    /** Value passed to {@code IForest.setMaxDepth}. */
    private int maxDepth = DEFAULT_MAX_DEPTH;

    public IsolationForestAlgorithm(SparkSession sparkSession) {
        super(sparkSession);
    }

    /**
     * Registers this operator's command-line options, parses {@code args} and stores the
     * resulting values in the instance fields.
     *
     * @param args raw command-line arguments
     * @return {@code true} when the command line was parsed and every required value
     *         (feature columns, source table, target table) is present; {@code false}
     *         when parsing fails, a numeric option is not a valid number, or a required
     *         value is missing
     */
    public boolean parseParams(String[] args) {
        super.parseParams(args);

        options.addOption(OptionHelper.getSpecOption("source"));
        options.addOption(OptionHelper.getSpecOption("target"));
        options.addOption(OptionHelper.getSpecOption("featureCols"));
        options.addOption(OptionHelper.getSpecOption("tn", "treeNum", "tree number", true));
        options.addOption(OptionHelper.getSpecOption("ms", "maxSamples", "maxSamples", true));
        options.addOption(OptionHelper.getSpecOption("c", "contamination", "contamination", true));
        options.addOption(OptionHelper.getSpecOption("md", "maxDepth", "max depth", true));

        if (!this.initCommandLine(args)) {
            logger.error("initCommandLine fail!!!");
            return false;
        }

        if (cmd.hasOption("f") || cmd.hasOption("featureCols")) {
            featureCols = cmd.getOptionValue("featureCols").split(",");
        }
        if (cmd.hasOption("s") || cmd.hasOption("source")) {
            sourceTable = cmd.getOptionValue("source");
        }
        if (cmd.hasOption("t") || cmd.hasOption("target")) {
            targetTable = cmd.getOptionValue("target");
        }

        // A malformed number is a parameter error like any other: report it through the
        // boolean result instead of letting NumberFormatException escape.
        try {
            if (cmd.hasOption("tn") || cmd.hasOption("treeNum")) {
                treeNum = Integer.parseInt(cmd.getOptionValue("treeNum"));
            }
            if (cmd.hasOption("ms") || cmd.hasOption("maxSamples")) {
                maxSamples = Double.parseDouble(cmd.getOptionValue("maxSamples"));
            }
            if (cmd.hasOption("c") || cmd.hasOption("contamination")) {
                contamination = Double.parseDouble(cmd.getOptionValue("contamination"));
            }
            if (cmd.hasOption("md") || cmd.hasOption("maxDepth")) {
                maxDepth = Integer.parseInt(cmd.getOptionValue("maxDepth"));
            }
        } catch (NumberFormatException e) {
            logger.error("numeric parameter parse fail!!!", e);
            return false;
        }

        // These three are dereferenced unconditionally by beginAlgorithm(); fail here with
        // a clear message rather than with a NullPointerException in the middle of the job.
        if (featureCols == null || featureCols.length == 0) {
            logger.error("parameter featureCols is missing!!!");
            return false;
        }
        if (sourceTable == null) {
            logger.error("parameter source is missing!!!");
            return false;
        }
        if (targetTable == null) {
            logger.error("parameter target is missing!!!");
            return false;
        }
        return true;
    }

    /**
     * Runs the operator end to end: load the input, fit and apply the isolation forest,
     * save the labelled rows to the target table and record the run's output metadata.
     *
     * @return {@code true} on success; {@code false} when the input could not be loaded
     *         (for example because caching the sampled table failed)
     */
    public boolean beginAlgorithm() {
        Dataset<Row> dataset = loadSourceDataset();
        if (dataset == null) {
            return false;
        }

        Dataset<Row> labelled = detectAnomalies(dataset);

        if (isHive) {
            this.saveHiveTable(labelled, targetTable);
        } else {
            UtilTool.saveGreenplumTable(labelled, targetTable);
        }

        Map<String, Object> inputKV = new HashMap<>();
        List<String> outputTables = new ArrayList<>();
        outputTables.add(targetTable);

        OutputResult outputResult = this.buildOutputResult(outputTables, 0, "", inputKV);

        if (isHive) {
            // The metadata path is only consumed by the Hive writer.
            String metaPath = String.format(UtilTool.RUN_META_PATH, algorithmName, uniqueKey);
            this.saveOutputResultForHive(sparkSession, outputResult, metaPath);
        } else {
            this.saveOutputResultForMysql(sparkSession, outputResult);
        }
        retResult = outputResult;
        return true;
    }

    /**
     * Loads the input rows according to the {@code isHive} / {@code isSample} flags.
     *
     * <p>In the non-Hive sampling case the sampled rows are served from a cache table when
     * one exists; otherwise they are read from Greenplum and cached for later runs.
     *
     * @return the input dataset, or {@code null} when no dataset could be obtained (a
     *         failure to cache the sampled table is logged before returning)
     */
    private Dataset<Row> loadSourceDataset() {
        if (isHive) {
            String sql = UtilTool.buildSelectSql(featureCols, sourceTable, idCol);
            Dataset<Row> selected = sparkSession.sql(sql);
            return isSample ? selected.sample(sampleRatio) : selected;
        }

        if (!isSample) {
            return UtilTool.readFromGreenPlum(sparkSession, sourceTable, idCol);
        }

        String cacheTable = cacheUtil.modifyCacheTableName(sourceTable);
        if (cacheUtil.isCacheTableExists(cacheTable)) {
            return sparkSession.table(cacheTable);
        }

        Dataset<Row> sampled =
                UtilTool.readFromGreenPlum(sparkSession, sourceTable, idCol, sampleNumber);
        if (!cacheUtil.cacheTableForDataset(sampled, cacheTable)) {
            logger.error("table = {} cache fail!!!", cacheTable);
            return null;
        }
        logger.info("table = {} cache success", cacheTable);
        return sampled;
    }

    /**
     * Fits the assembler + isolation-forest pipeline on {@code dataset} and applies it to
     * the same rows.
     *
     * @param dataset input rows containing every column named in {@code featureCols}
     * @return the transformed rows without the intermediate feature vector and with the
     *         {@code label} column converted to a string and placed last
     */
    private Dataset<Row> detectAnomalies(Dataset<Row> dataset) {
        // "skip" drops rows the assembler considers invalid instead of failing the job.
        VectorAssembler assembler = new VectorAssembler()
                .setHandleInvalid("skip")
                .setInputCols(featureCols)
                .setOutputCol(FEATURES_COL);

        IForest iForest = new IForest()
                .setNumTrees(treeNum)
                .setMaxSamples(maxSamples)
                .setContamination(contamination)
                .setBootstrap(false)
                .setMaxDepth(maxDepth)
                .setPredictionCol(LABEL_COL)
                .setSeed(SEED);

        Pipeline pipeline = new Pipeline().setStages(new PipelineStage[]{assembler, iForest});
        //TODO save model?
        PipelineModel model = pipeline.fit(dataset);
        Dataset<Row> scored = model.transform(dataset).drop(FEATURES_COL);

        // Cast through a scratch column and rename it back so that the string label ends
        // up as the last column, exactly where the previous implementation put it.
        // NOTE(review): an earlier draft of this operator sketched remapping label 1 to -1
        // and every other value to 1; the value saved here is the plain string form of the
        // model output — confirm which encoding downstream consumers expect.
        return scored
                .withColumn(LABEL_STR_COL, scored.col(LABEL_COL).cast(DataTypes.StringType))
                .drop(LABEL_COL)
                .withColumnRenamed(LABEL_STR_COL, LABEL_COL);
    }
}
